#!/usr/bin/env python3

import sys
import json

def get_mem_info(filename):
    """Read the tile count and memory size from a JSON configuration file.

    Args:
        filename: Path to a JSON file containing a ``cpus`` object and a
            ``memory@100000000`` object.

    Returns:
        A ``(n_tiles, mem_size)`` tuple. ``n_tiles`` is the number of keys
        under ``cpus`` whose name starts with ``cpu@``; ``mem_size`` is the
        ``size`` field of the first ``reg`` entry of ``memory@100000000``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError, IndexError: If the expected entries are missing.
    """
    # Use a context manager so the handle is closed even if parsing fails.
    with open(filename, 'r') as config_file:
        config = json.load(config_file)

    # Only 'cpu@N' nodes are tiles; other children of 'cpus' are skipped.
    n_tiles = sum(1 for name in config['cpus'] if name.startswith('cpu@'))
    mem_size = config['memory@100000000']['reg'][0]['size']
    return n_tiles, mem_size

class MemInitGen:
    """Line-by-line filter that swaps a module's memory-randomisation block.

    Lines are fed one at a time to :meth:`transfer`, which echoes them to
    stdout. Inside the module whose declaration starts with
    ``module <module>``, every region from a line starting with
    ```ifdef RANDOMIZE_MEM_INIT`` up to the next line starting with
    ```endif`` is dropped and replaced by an ``initial`` block of
    ``$readmemh`` calls (two per tile).

    The module is matched by prefix, so ``module foo`` also matches a
    declaration of ``module foobar``.
    """

    # Parser states.
    NORMAL = 'NORMAL'      # outside the target module
    MODULE = 'MODULE'      # inside the target module
    MEM_INIT = 'MEM INIT'  # inside the block being replaced

    def __init__(self, module, bin_file, n_tiles, mem_size):
        """Set up the filter.

        Args:
            module: Name of the module whose init block is replaced.
            bin_file: File name passed to the first ``$readmemh`` of each tile.
            n_tiles: Number of tiles; one pair of ``$readmemh`` calls each.
            mem_size: Total memory size, split evenly between tiles.
        """
        self.MODULE_START = 'module ' + module
        self.MODULE_END = 'endmodule'
        self.MEM_INIT_START = '`ifdef RANDOMIZE_MEM_INIT'
        self.MEM_INIT_END = '`endif'
        self.state = MemInitGen.NORMAL
        self.bin_file = bin_file
        self.n_tiles = n_tiles
        self.mem_size = mem_size
        self.__found_module = False
        self.__emitted = False

    def transfer(self, line):
        """Consume one input line, printing whatever belongs in the output.

        Args:
            line: A single source line without its trailing newline.
        """
        if self.state == MemInitGen.NORMAL:
            if line.startswith(self.MODULE_START):
                self.__found_module = True
                self.state = MemInitGen.MODULE
            print(line)
        elif self.state == MemInitGen.MODULE:
            if line.startswith(self.MEM_INIT_START):
                # The `ifdef line itself is swallowed and replaced.
                self.__emitted = True
                self.state = MemInitGen.MEM_INIT
                self._emit_mem_init()
            else:
                if line.startswith(self.MODULE_END):
                    self.state = MemInitGen.NORMAL
                print(line)
        elif self.state == MemInitGen.MEM_INIT:
            # Drop everything up to and including the closing `endif.
            if line.startswith(self.MEM_INIT_END):
                self.state = MemInitGen.MODULE

    def _emit_mem_init(self):
        """Print the ``initial`` block that loads each tile's memory image."""
        print('initial begin')
        for i in range(self.n_tiles):
            # NOTE(review): the // 8 presumably converts a byte size into
            # 64-bit word indices of `ram` -- confirm against the RTL.
            base = self.mem_size // self.n_tiles // 8 * i
            print('$readmemh("{0}", ram, \'h{1:x});'
                  .format(self.bin_file, base))
            print('$readmemh("{0}", ram, \'h{1:x});'
                  .format('c' + str(i) + '.dtb.txt', base + 1))
        print('end')

    @property
    def found_module(self):
        """True once a line starting the target module has been seen."""
        return self.__found_module

    @property
    def emitted(self):
        """True once at least one init block has been replaced."""
        return self.__emitted

if __name__ == '__main__':
    # Usage: <input-file> <module-name> <init-file> <rocket-config.json>
    # The rewritten source is written to stdout.
    if len(sys.argv) < 5:
        sys.exit(sys.argv[0] + ' <input-file> <module-name> <init-file> <rocket-config.json>')

    n_tiles, mem_size = get_mem_info(sys.argv[4])
    mem_init_gen = MemInitGen(sys.argv[2], sys.argv[3], n_tiles, mem_size)

    # Close the input file deterministically instead of leaking the handle.
    with open(sys.argv[1]) as source_file:
        for line in source_file:
            # Leading/trailing whitespace is removed before matching, so the
            # echoed output loses the input's indentation.
            mem_init_gen.transfer(line.strip())

    if not mem_init_gen.found_module:
        sys.exit('Module not found')
    if not mem_init_gen.emitted:
        sys.exit('Memory not initialized')
